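/*
 * This simulation connects one STA and one AP through three channels, as shown below:
 *       ---Channel 36, 20 MHz, 5 GHz---
 *       ---Channel 40, 20 MHz, 5 GHz---
 *       ---Channel 1, 20 MHz, 2.4 GHz---
 *   sta                                     ap
 * The goal is to check how ns-3 handles multiple links when EMLSR is not explicitly enabled.
 * Result: once the Multi-Link element is carried in the Wi-Fi beacon and the two devices have
 * authenticated on one channel, links are set up on all of the channels.
 * The results also show that the two devices take turns using the different links for data
 * transmission, which closely resembles the way EMLSR operates.
 *
 * A second STA that uses only a single channel has now been added. Some issues came up while
 * adding it: separate WifiHelper, PHY helper and MAC helper objects can be created for it, but
 * the same channel object must be shared, otherwise the STA and the AP cannot communicate.
 * This suggests that, internally, ns-3 uses the channel object to connect the nodes.
 *
 * Known limitations: more STAs are not yet placed on the multiple links, and no stronger noise is
 * introduced on the links to check the Wi-Fi 7 property of stable communication quality over
 * multiple links. */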

#include "ns3/boolean.h"
#include "ns3/command-line.h"
#include "ns3/config.h"
#include "ns3/double.h"
#include "ns3/eht-phy.h"
#include "ns3/enum.h"
#include "ns3/frame-exchange-manager.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/log.h"
#include "ns3/mobility-helper.h"
#include "ns3/multi-model-spectrum-channel.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/rng-seed-manager.h"
#include "ns3/spectrum-wifi-helper.h"
#include "ns3/ssid.h"
#include "ns3/string.h"
#include "ns3/udp-client-server-helper.h"
#include "ns3/uinteger.h"
#include "ns3/wifi-acknowledgment.h"
#include "ns3/wifi-mac.h"
#include "ns3/wifi-net-device.h"
#include "ns3/yans-wifi-channel.h"
#include "ns3/yans-wifi-helper.h"
// #include "ns3/myTools.h"

#include <array>
#include <functional>
#include <numeric>

using namespace ns3;

NS_LOG_COMPONENT_DEFINE("EhtStaApLinkLink");

/**
 * Report, for every server application in a container, how many bytes it has received.
 *
 * \param udp true if the servers are UdpServer applications (UDP traffic), false if they
 *        are PacketSink applications (TCP traffic)
 * \param serverApp a container of server applications
 * \param payloadSize the size in bytes of each packet; only used in the UDP case, where the
 *        byte count is payloadSize multiplied by UdpServer::GetReceived()
 * \return the bytes received by each server application, in the same order as serverApp
 */
std::vector<uint64_t>
StaApLinkLinkGetRxBytes(bool udp, const ApplicationContainer& serverApp, uint32_t payloadSize)
{
    const uint32_t appCount = serverApp.GetN();
    std::vector<uint64_t> bytesPerApp;
    bytesPerApp.reserve(appCount);

    for (uint32_t idx = 0; idx < appCount; ++idx)
    {
        if (udp)
        {
            // GetReceived() is treated as a packet count, so scale it by the fixed payload size.
            const auto packets = DynamicCast<UdpServer>(serverApp.Get(idx))->GetReceived();
            bytesPerApp.push_back(payloadSize * packets);
        }
        else
        {
            bytesPerApp.push_back(DynamicCast<PacketSink>(serverApp.Get(idx))->GetTotalRx());
        }
    }
    return bytesPerApp;
}

int
main(int argc, char* argv[])
{
    double simulationTime = 3; // seconds
    double distance = 1.0;     // meters

    int mcs = 13;
    int gi = 3200;              // ns
    uint32_t payloadSize = 700; // bytes
    // uint64_t udpMaxPacketNumber = 4294967295U;
    // udpMaxPacketNumber = 4294967295U;

    uint8_t nLinks = 3;
    std::string channelLink0 =
        "{36, 20, BAND_5GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                  // index} set to wifi-phy.cc and from
                                  // wifi-phy-operating-channel.cc
    std::string channelLink1 =
        "{40, 20, BAND_5GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                  // index} set to wifi-phy.cc and from
                                  // wifi-phy-operating-channel.cc

    std::string channelLink2 =
        "{1, 20, BAND_2_4GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                   // index} set to wifi-phy.cc and from
                                   // wifi-phy-operating-channel.cc

    NodeContainer staNodes;
    staNodes.Create(1);
    NodeContainer apNode;
    apNode.Create(1);

    NetDeviceContainer staDevices;
    NetDeviceContainer apDevice;
    WifiMacHelper wifiMac;
    WifiHelper wifiHelper;

    wifiHelper.SetStandard(WIFI_STANDARD_80211be);
    Ssid ssid = Ssid("ns3-80211be");

    Ptr<MultiModelSpectrumChannel> spectrumChannel = CreateObject<MultiModelSpectrumChannel>();
    Ptr<LogDistancePropagationLossModel> lossModel =
        CreateObject<LogDistancePropagationLossModel>();
    spectrumChannel->AddPropagationLossModel(lossModel);

    SpectrumWifiPhyHelper phy(nLinks);
    phy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
    phy.SetChannel(spectrumChannel);

    NodeContainer staNodes1;
    NetDeviceContainer staDevice1;
    bool flagsta1{true};
    /* 再往这个网络中添加一些仅连接到单个Link上的sta */
    if (flagsta1)
    {
        staNodes1.Create(1);
        // WifiHelper &wifiHelper1 = wifiHelper;
        WifiHelper wifiHelper1;
        wifiHelper1.SetStandard(WIFI_STANDARD_80211be);
        uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
        std::string ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        wifiHelper1.SetRemoteStationManager((int)0,
                                            "ns3::ConstantRateWifiManager",
                                            "DataMode",
                                            StringValue("EhtMcs" + std::to_string(mcs)),
                                            "ControlMode",
                                            StringValue(ctrlRateStr));
        // Ptr<MultiModelSpectrumChannel> spectrumChannel1 =
        // CreateObject<MultiModelSpectrumChannel>(); Ptr<LogDistancePropagationLossModel>
        // lossModel1 =
        //     CreateObject<LogDistancePropagationLossModel>();
        // spectrumChannel1->AddPropagationLossModel(lossModel1);
        Ptr<MultiModelSpectrumChannel> spectrumChannel1 = CreateObject<MultiModelSpectrumChannel>();
        uint8_t nLinks1 = 1;
        SpectrumWifiPhyHelper phy1(nLinks1);
        phy1.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
        phy1.SetChannel(spectrumChannel);
        phy1.Set(0, "ChannelSettings", StringValue(channelLink0));
        WifiMacHelper wifiMac1;
        wifiMac1.SetType("ns3::StaWifiMac", "Ssid", SsidValue(ssid));
        staDevice1 = wifiHelper1.Install(phy1, wifiMac1, staNodes1);
    }
    /* 再往这个网络中添加一些仅连接到单个Link上的sta */

    for (uint8_t i = 0; i < nLinks; i++)
    {
        uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
        std::string ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        wifiHelper.SetRemoteStationManager(i,
                                           "ns3::ConstantRateWifiManager",
                                           "DataMode",
                                           StringValue("EhtMcs" + std::to_string(mcs)),
                                           "ControlMode",
                                           StringValue(ctrlRateStr));
    }

    wifiMac.SetType("ns3::StaWifiMac", "Ssid", SsidValue(ssid));
    phy.Set(0, "ChannelSettings", StringValue(channelLink0));
    phy.Set(1, "ChannelSettings", StringValue(channelLink1));
    phy.Set(2, "ChannelSettings", StringValue(channelLink2));
    staDevices = wifiHelper.Install(phy, wifiMac, staNodes);

    if (flagsta1)
    {
        staDevices.Add(staDevice1);
        staNodes.Add(staNodes1);
    }

    // 查看每隔staDevice的所有mac地址
    std::cout << "staDevices info :" << std::endl;
    for (u_int16_t i = 0; i < staDevices.GetN(); i++)
    {
        const Ptr<WifiNetDevice> device = DynamicCast<WifiNetDevice>(staDevices.Get(i));
        Ptr<WifiMac> mac = device->GetMac();
        for (uint8_t linkId = 0; linkId < std::max<uint8_t>(device->GetNPhys(), 1); ++linkId)
        {
            auto fem = mac->GetFrameExchangeManager(linkId);
            std::cout << "staDevice " << i << " linkId " << std::to_string(linkId)
                      << " mac address: " << fem->GetAddress() << std::endl;
        }
    }

    // 考虑到仅仅只有一个STA，因此就不使用多用户ACK机制

    wifiMac.SetType("ns3::ApWifiMac",
                    "Ssid",
                    SsidValue(ssid),
                    "EnableBeaconJitter",
                    BooleanValue(false));
    apDevice = wifiHelper.Install(phy, wifiMac, apNode);

    RngSeedManager::SetSeed(1);
    RngSeedManager::SetRun(1);
    int64_t streamNumber = 100;
    streamNumber += wifiHelper.AssignStreams(apDevice, streamNumber);
    streamNumber += wifiHelper.AssignStreams(
        staDevices,
        streamNumber); // 既为不同种类的device分配了不同的随机数种子，又保证了在不同的仿真实验中，随机数种子相同，从而保证了实验的可重复性。

    // Set guard interval and MPDU buffer size
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/GuardInterval",
                TimeValue(NanoSeconds(gi))); // 保护间隔，即每个OFDM符号之间的间隔

    // mobility.
    // 设置移动模型，即设置移动节点的位置，此处仅是静态模型
    MobilityHelper mobility;
    Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator>();

    positionAlloc->Add(Vector(0.0, 0.0, 0.0));
    positionAlloc->Add(Vector(distance, 0.0, 0.0));
    mobility.SetPositionAllocator(positionAlloc);

    mobility.SetMobilityModel("ns3::ConstantPositionMobilityModel");

    mobility.Install(apNode);
    mobility.Install(staNodes);

    /* Internet stack*/
    InternetStackHelper stack;
    stack.Install(apNode);
    stack.Install(staNodes);

    Ipv4AddressHelper address;
    address.SetBase("192.168.1.0", "255.255.255.0");
    Ipv4InterfaceContainer staNodeInterfaces;
    Ipv4InterfaceContainer apNodeInterface;

    staNodeInterfaces = address.Assign(staDevices);
    apNodeInterface = address.Assign(apDevice);

    /* Setting applications */
    ApplicationContainer serverApp;
    auto serverNodes = apNode;
    Ipv4InterfaceContainer serverInterfaces;
    NodeContainer clientNodes;

    serverInterfaces.Add(apNodeInterface.Get(0));
    clientNodes.Add(staNodes);

    for (uint16_t port = 9; port < 19; port++)
    {
        // uint16_t port = 9;
        UdpServerHelper server(port);
        serverApp = server.Install(serverNodes.Get(0));
        serverApp.Start(Seconds(0.0));
        serverApp.Stop(Seconds(simulationTime + 1));

        UdpClientHelper client(serverInterfaces.GetAddress(0), port);
        client.SetAttribute("MaxPackets", UintegerValue(4294967295U));
        client.SetAttribute("Interval", TimeValue(Time("0.00001"))); // packets/s
        client.SetAttribute("PacketSize", UintegerValue(payloadSize));
        ApplicationContainer clientApp = client.Install(clientNodes);
        clientApp.Start(Seconds(1.0));
        clientApp.Stop(Seconds(simulationTime + 1));
    }

    // 打印udp服务器端及客户端的mac地址和ip地址
    std::cout << "UDP info :" << std::endl;
    std::cout << "Server IP address: " << serverInterfaces.GetAddress(0) << std::endl;
    std::cout << "Server MAC address: " << staDevices.Get(0)->GetAddress() << std::endl;
    std::cout << "Client IP address: " << apNodeInterface.GetAddress(0) << std::endl;
    std::cout << "Client MAC address: " << apDevice.Get(0)->GetAddress() << std::endl;

    std::cout << "NetDeviceContainer info :" << std::endl;
    std::cout << "Number of apDevice " << apDevice.GetN() << std::endl;
    std::cout << "Number of staDevices " << staDevices.GetN() << std::endl;

    std::cout << "apNode NetDevices info :" << std::endl;
    for (u_int16_t i = 0; i < apNode.Get(0)->GetNDevices(); i++)
    {
        std::cout << "apNode NetDevice " << i
                  << " mac address: " << apNode.Get(0)->GetDevice(i)->GetAddress() << std::endl;
    }

    std::cout << "staNode NetDevices info :" << std::endl;
    for (u_int16_t i = 0; i < staNodes.Get(0)->GetNDevices(); i++)
    {
        std::cout << "staNode NetDevice " << i
                  << " mac address: " << staNodes.Get(0)->GetDevice(i)->GetAddress() << std::endl;
    }

    std::string pcapFileName = "EhtStaApLinkLinkLink-ap";
    phy.EnablePcap(pcapFileName, apDevice, true);
    std::string pcapFileName1 = "EhtStaApLinkLinkLink-sta";
    phy.EnablePcap(pcapFileName1, staDevices, true);

    std::cout << "MCS value"
              << "\t\t"
              << "Channel width"
              << "\t\t"
              << "GI"
              << "\t\t\t"
              << "Throughput" << '\n';

    Simulator::Stop(Seconds(simulationTime + 1));
    Simulator::Run();

    // cumulative number of bytes received by each server application
    std::vector<uint64_t> cumulRxBytes(1, 0);
    cumulRxBytes = StaApLinkLinkGetRxBytes(true, serverApp, payloadSize);
    uint64_t rxBytes =
        std::accumulate(cumulRxBytes.cbegin(), cumulRxBytes.cend(), 0); // 向量中元素累加
    double throughput = (rxBytes * 8) / (simulationTime * 1000000.0);   // Mbit/s

    Simulator::Destroy();
    std::cout << mcs << "\t\t\t"
              << "20 MHz\t\t\t" << gi << " ns\t\t\t" << throughput << " Mbit/s" << std::endl;

    return 0;
}